{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "etTmZVFN8fYO"
      },
      "source": [
        "This notebook runs a basic speed test for a short training loop of a neural network training on the MNIST dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "eqOvRhOz8SWs"
      },
      "source": [
        "### Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "nHY0tntRizGb"
      },
      "outputs": [],
      "source": [
        "!pip install -U -q tf-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "Pa2qpEmoVOGe"
      },
      "outputs": [],
      "source": [
        "import gzip\n",
        "import os\n",
        "import shutil\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "import six\n",
        "from six.moves import urllib\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.contrib import autograph as ag\n",
        "from tensorflow.contrib.eager.python import tfe\n",
        "from tensorflow.python.eager import context\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PZWxEJFM9A7b"
      },
      "source": [
        "### Testing boilerplate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "kfZk9EFZ5TeQ"
      },
      "outputs": [],
      "source": [
        "# Test-only parameters. Test checks successful completion not correctness. \n",
        "burn_ins = 1\n",
        "trials = 1\n",
        "max_steps = 2\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "k0GKbZBJ9Gt9"
      },
      "source": [
        "### Speed test configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "gWXV8WHn43iZ"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true} \n",
        "burn_ins = 3\n",
        "trials = 10\n",
        "max_steps = 500\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kZV_3pGy8033"
      },
      "source": [
        "### Data source setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "YfnHJbBOBKae"
      },
      "outputs": [],
      "source": [
        "def download(directory, filename):\n",
        "  filepath = os.path.join(directory, filename)\n",
        "  if tf.gfile.Exists(filepath):\n",
        "    return filepath\n",
        "  if not tf.gfile.Exists(directory):\n",
        "    tf.gfile.MakeDirs(directory)\n",
        "  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
        "  zipped_filepath = filepath + '.gz'\n",
        "  print('Downloading %s to %s' % (url, zipped_filepath))\n",
        "  urllib.request.urlretrieve(url, zipped_filepath)\n",
        "  with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
        "    shutil.copyfileobj(f_in, f_out)\n",
        "  os.remove(zipped_filepath)\n",
        "  return filepath\n",
        "\n",
        "\n",
        "def dataset(directory, images_file, labels_file):\n",
        "  images_file = download(directory, images_file)\n",
        "  labels_file = download(directory, labels_file)\n",
        "\n",
        "  def decode_image(image):\n",
        "    # Normalize from [0, 255] to [0.0, 1.0]\n",
        "    image = tf.decode_raw(image, tf.uint8)\n",
        "    image = tf.cast(image, tf.float32)\n",
        "    image = tf.reshape(image, [784])\n",
        "    return image / 255.0\n",
        "\n",
        "  def decode_label(label):\n",
        "    label = tf.decode_raw(label, tf.uint8)\n",
        "    label = tf.reshape(label, [])\n",
        "    return tf.to_int32(label)\n",
        "\n",
        "  images = tf.data.FixedLengthRecordDataset(\n",
        "      images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
        "  labels = tf.data.FixedLengthRecordDataset(\n",
        "      labels_file, 1, header_bytes=8).map(decode_label)\n",
        "  return tf.data.Dataset.zip((images, labels))\n",
        "\n",
        "\n",
        "def mnist_train(directory):\n",
        "  return dataset(directory, 'train-images-idx3-ubyte',\n",
        "                 'train-labels-idx1-ubyte')\n",
        "\n",
        "def mnist_test(directory):\n",
        "  return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')\n",
        "\n",
        "def setup_mnist_data(is_training, hp, batch_size):\n",
        "  if is_training:\n",
        "    ds = mnist_train('/tmp/autograph_mnist_data')\n",
        "    ds = ds.cache()\n",
        "    ds = ds.shuffle(batch_size * 10)\n",
        "  else:\n",
        "    ds = mnist_test('/tmp/autograph_mnist_data')\n",
        "    ds = ds.cache()\n",
        "  ds = ds.repeat()\n",
        "  ds = ds.batch(batch_size)\n",
        "  return ds\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qzkZyZcS9THu"
      },
      "source": [
        "### Keras model definition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "x_MU13boiok2"
      },
      "outputs": [],
      "source": [
        "def mlp_model(input_shape):\n",
        "  model = tf.keras.Sequential((\n",
        "      tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
        "      tf.keras.layers.Dense(100, activation='relu'),\n",
        "      tf.keras.layers.Dense(10, activation='softmax')))\n",
        "  model.build()\n",
        "  return model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DXt4GoTxtvn2"
      },
      "source": [
        "# AutoGraph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "W51sfbONiz_5"
      },
      "outputs": [],
      "source": [
        "def predict(m, x, y):\n",
        "  y_p = m(x)\n",
        "  losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
        "  l = tf.reduce_mean(losses)\n",
        "  accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
        "  accuracy = tf.reduce_mean(accuracies)\n",
        "  return l, accuracy\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "CsAD0ajbi9iZ"
      },
      "outputs": [],
      "source": [
        "def fit(m, x, y, opt):\n",
        "  l, accuracy = predict(m, x, y)\n",
        "  opt.minimize(l)\n",
        "  return l, accuracy\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "RVw57HdTjPzi"
      },
      "outputs": [],
      "source": [
        "def get_next_batch(ds):\n",
        "  itr = ds.make_one_shot_iterator()\n",
        "  image, label = itr.get_next()\n",
        "  x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
        "  y = tf.one_hot(tf.squeeze(label), 10)\n",
        "  return x, y\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "UUI0566FjZPx"
      },
      "outputs": [],
      "source": [
        "def train(train_ds, test_ds, hp):\n",
        "  m = mlp_model((28 * 28,))\n",
        "  opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
        "\n",
        "  train_losses = []\n",
        "  test_losses = []\n",
        "  train_accuracies = []\n",
        "  test_accuracies = []\n",
        "  ag.set_element_type(train_losses, tf.float32)\n",
        "  ag.set_element_type(test_losses, tf.float32)\n",
        "  ag.set_element_type(train_accuracies, tf.float32)\n",
        "  ag.set_element_type(test_accuracies, tf.float32)\n",
        "\n",
        "  i = tf.constant(0)\n",
        "  while i \u003c hp.max_steps:\n",
        "    train_x, train_y = get_next_batch(train_ds)\n",
        "    test_x, test_y = get_next_batch(test_ds)\n",
        "    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
        "    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
        "\n",
        "    train_losses.append(step_train_loss)\n",
        "    test_losses.append(step_test_loss)\n",
        "    train_accuracies.append(step_train_accuracy)\n",
        "    test_accuracies.append(step_test_accuracy)\n",
        "\n",
        "    i += 1\n",
        "  return (ag.stack(train_losses), ag.stack(test_losses),\n",
        "          ag.stack(train_accuracies), ag.stack(test_accuracies))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "height": 215
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 12156,
          "status": "ok",
          "timestamp": 1531752050611,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "K1m8TwOKjdNd",
        "outputId": "bd5746f2-bf91-44aa-9eff-38eb11ced33f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "('Duration:', 0.6226680278778076)\n",
            "('Duration:', 0.6082069873809814)\n",
            "('Duration:', 0.6223258972167969)\n",
            "('Duration:', 0.6176440715789795)\n",
            "('Duration:', 0.6309840679168701)\n",
            "('Duration:', 0.6180410385131836)\n",
            "('Duration:', 0.6219630241394043)\n",
            "('Duration:', 0.6183009147644043)\n",
            "('Duration:', 0.6176400184631348)\n",
            "('Duration:', 0.6476900577545166)\n",
            "('Mean duration:', 0.62254641056060789, '+/-', 0.0099792188690656976)\n"
          ]
        }
      ],
      "source": [
        "#@test {\"timeout\": 90}\n",
        "with tf.Graph().as_default():\n",
        "  hp = tf.contrib.training.HParams(\n",
        "      learning_rate=0.05,\n",
        "      max_steps=max_steps,\n",
        "  )\n",
        "  train_ds = setup_mnist_data(True, hp, 500)\n",
        "  test_ds = setup_mnist_data(False, hp, 100)\n",
        "  tf_train = ag.to_graph(train)\n",
        "  losses = tf_train(train_ds, test_ds, hp)\n",
        "\n",
        "  with tf.Session() as sess:\n",
        "    durations = []\n",
        "    for t in range(burn_ins + trials):\n",
        "      sess.run(tf.global_variables_initializer())\n",
        "\n",
        "      start = time.time()\n",
        "      (train_losses, test_losses, train_accuracies,\n",
        "       test_accuracies) = sess.run(losses)\n",
        "\n",
        "      if t \u003c burn_ins:\n",
        "        continue\n",
        "\n",
        "      duration = time.time() - start\n",
        "      durations.append(duration)\n",
        "      print('Duration:', duration)\n",
        "\n",
        "    print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "A06kdgtZtlce"
      },
      "source": [
        "# Eager"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "hBKOKGrWty4e"
      },
      "outputs": [],
      "source": [
        "def predict(m, x, y):\n",
        "  y_p = m(x)\n",
        "  losses = tf.keras.losses.categorical_crossentropy(tf.cast(y, tf.float32), y_p)\n",
        "  l = tf.reduce_mean(losses)\n",
        "  accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
        "  accuracy = tf.reduce_mean(accuracies)\n",
        "  return l, accuracy\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "HCgTZ0MTt6vt"
      },
      "outputs": [],
      "source": [
        "def train(ds, hp):\n",
        "  m = mlp_model((28 * 28,))\n",
        "  opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
        "\n",
        "  train_losses = []\n",
        "  test_losses = []\n",
        "  train_accuracies = []\n",
        "  test_accuracies = []\n",
        "\n",
        "  i = 0\n",
        "  train_test_itr = tfe.Iterator(ds)\n",
        "  for (train_x, train_y), (test_x, test_y) in train_test_itr:\n",
        "    train_x = tf.to_float(tf.reshape(train_x, (-1, 28 * 28)))\n",
        "    train_y = tf.one_hot(tf.squeeze(train_y), 10)\n",
        "    test_x = tf.to_float(tf.reshape(test_x, (-1, 28 * 28)))\n",
        "    test_y = tf.one_hot(tf.squeeze(test_y), 10)\n",
        "\n",
        "    if i \u003e hp.max_steps:\n",
        "      break\n",
        "\n",
        "    with tf.GradientTape() as tape:\n",
        "      step_train_loss, step_train_accuracy = predict(m, train_x, train_y)\n",
        "    grad = tape.gradient(step_train_loss, m.variables)\n",
        "    opt.apply_gradients(zip(grad, m.variables))\n",
        "    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
        "\n",
        "    train_losses.append(step_train_loss)\n",
        "    test_losses.append(step_test_loss)\n",
        "    train_accuracies.append(step_train_accuracy)\n",
        "    test_accuracies.append(step_test_accuracy)\n",
        "\n",
        "    i += 1\n",
        "  return train_losses, test_losses, train_accuracies, test_accuracies\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "height": 215
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 52499,
          "status": "ok",
          "timestamp": 1531752103279,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "plv_yrn_t8Dy",
        "outputId": "55d5ab3d-252d-48ba-8fb4-20ec3c3e6d00"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "('Duration:', 3.9973549842834473)\n",
            "('Duration:', 4.018772125244141)\n",
            "('Duration:', 3.9740989208221436)\n",
            "('Duration:', 3.9922947883605957)\n",
            "('Duration:', 3.9795801639556885)\n",
            "('Duration:', 3.966722011566162)\n",
            "('Duration:', 3.986541986465454)\n",
            "('Duration:', 3.992305040359497)\n",
            "('Duration:', 4.012261867523193)\n",
            "('Duration:', 4.004716157913208)\n",
            "('Mean duration:', 3.9924648046493529, '+/-', 0.015681688635624851)\n"
          ]
        }
      ],
      "source": [
        "#@test {\"timeout\": 90}\n",
        "with context.eager_mode():\n",
        "  durations = []\n",
        "  for t in range(burn_ins + trials):\n",
        "    hp = tf.contrib.training.HParams(\n",
        "        learning_rate=0.05,\n",
        "        max_steps=max_steps,\n",
        "    )\n",
        "    train_ds = setup_mnist_data(True, hp, 500)\n",
        "    test_ds = setup_mnist_data(False, hp, 100)\n",
        "    ds = tf.data.Dataset.zip((train_ds, test_ds))\n",
        "    start = time.time()\n",
        "    (train_losses, test_losses, train_accuracies,\n",
        "     test_accuracies) = train(ds, hp)\n",
        "    \n",
        "    train_losses[-1].numpy()\n",
        "    test_losses[-1].numpy()\n",
        "    train_accuracies[-1].numpy()\n",
        "    test_accuracies[-1].numpy()\n",
        "\n",
        "    if t \u003c burn_ins:\n",
        "      continue\n",
        "\n",
        "    duration = time.time() - start\n",
        "    durations.append(duration)\n",
        "    print('Duration:', duration)\n",
        "\n",
        "  print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "eqOvRhOz8SWs",
        "PZWxEJFM9A7b",
        "kZV_3pGy8033"
      ],
      "default_view": {},
      "name": "Autograph vs. Eager MNIST speed test",
      "provenance": [
        {
          "file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD",
          "timestamp": 1530297010607
        },
        {
          "file_id": "18dCjshrmHiPTIe1CNsL8tnpdGkuXgpM9",
          "timestamp": 1530289467317
        },
        {
          "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
          "timestamp": 1522272821237
        },
        {
          "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
          "timestamp": 1522238054357
        },
        {
          "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
          "timestamp": 1521743157199
        },
        {
          "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
          "timestamp": 1520522344607
        }
      ],
      "version": "0.3.2",
      "views": {}
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
